Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor contact forces sum in api.ode.system_velocity_dynamics #180

Merged
merged 2 commits into from
Jun 20, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Jun 14, 2024

This PR refactors contact forces sum in system_velocity_dynamics potentially improving readability and performance


📚 Documentation preview 📚: https://jaxsim--180.org.readthedocs.build//180/

src/jaxsim/api/ode.py Outdated Show resolved Hide resolved
@flferretti flferretti force-pushed the flferretti-patch-2 branch 3 times, most recently from 9f7cdf5 to 4eb6d47 Compare June 17, 2024 16:01
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the original desiderata of this PR? Runtime performance, memory footprint, readability, etc?

It seems to me that now readability is much worse than before, and on a quick test I've done, the new logic seems wrong. Did you double check it?

Comment on lines 153 to 160
W_f_Li_terrain = jnp.where(
(
parent_link_index_of_collidable_points[:, None]
== jnp.arange(model.number_of_links())
).any(axis=-1, keepdims=True),
W_f_Ci,
jnp.zeros_like(W_f_Ci),
).sum(axis=0)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was both to make the code simpler, as the current version is quite intricate, and reduce the memory footprint of that part of code. If you prefer, I can split the binary mask and the jnp.where(...) as:

Suggested change
W_f_Li_terrain = jnp.where(
(
parent_link_index_of_collidable_points[:, None]
== jnp.arange(model.number_of_links())
).any(axis=-1, keepdims=True),
W_f_Ci,
jnp.zeros_like(W_f_Ci),
).sum(axis=0)
mask = (
parent_link_index_of_collidable_points[:, None]
== jnp.arange(model.number_of_links())
).any(axis=-1, keepdims=True)
W_f_Li_terrain = jnp.where(
mask,
W_f_Ci,
jnp.zeros_like(W_f_Ci),
).sum(axis=0)

@flferretti
Copy link
Collaborator Author

What was the original desiderata of this PR? Runtime performance, memory footprint, readability, etc?

It seems to me that now readability is much worse than before, and on a quick test I've done, the new logic seems wrong. Did you double check it?

I answered you with a comment in the code.

For what regards instead the logic, in which example did you get the error? I did the test simulating spheres in vmap and I had no problem

@diegoferigo
Copy link
Member

For what regards instead the logic, in which example did you get the error? I did the test simulating spheres in vmap and I had no problem

Check the following snippet, I get different results with the new logic. Did I do any mistake?

Code
import jax
import jax.numpy as jnp
import jaxsim.api as js
import resolve_robotics_uri_py
import rod

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
    uri="model://ergoCubSN001/model.urdf"
)

# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
)

# Get the parent body of the collidable points.
parent_link_index_of_collidable_points = jnp.array(
    model.kin_dyn_parameters.contact_parameters.body
)

# Create the 6D forces of collidable points.
nc = len(parent_link_index_of_collidable_points)
W_f_Ci = jnp.ones(shape=(nc, 6))

# =========
# Old logic
# =========

f_old = jax.vmap(
    lambda nc: (
        jnp.vstack(jnp.equal(parent_link_index_of_collidable_points, nc).astype(int))
        * W_f_Ci
    ).sum(axis=0)
)(jnp.arange(model.number_of_links()))

# =========
# New logic
# =========

f_new = jnp.where(
    (
        parent_link_index_of_collidable_points[:, None]
        == jnp.arange(model.number_of_links())
    ).any(axis=-1, keepdims=True),
    W_f_Ci,
    jnp.zeros_like(W_f_Ci),
).sum(axis=0)

@flferretti
Copy link
Collaborator Author

I found another way to further simplify the code, if you prefer I could merge the two lines. Ready for review @diegoferigo, thanks!

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better. I also tested on the script of my previous comment and it seems working fine now.

src/jaxsim/api/ode.py Outdated Show resolved Hide resolved
@flferretti flferretti merged commit a963649 into main Jun 20, 2024
29 checks passed
@flferretti flferretti deleted the flferretti-patch-2 branch June 20, 2024 12:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants